/**
 * @file MeanShift.cpp
 * @author enemy1205 (enemy1205@qq.com)
 * @date 2021-09-02
 */
#include "MeanShift.h"
using namespace std;
using namespace cv;
/**
 * @brief Probability density function of the standard normal distribution N(0, 1).
 * @param x point at which the density is evaluated
 * @return exp(-x^2 / 2) / sqrt(2 * PI)
 */
double Guass(double x) {
    const double normaliser = 1.0 / (sqrt(2 * PI)); // 1 / sqrt(2*pi)
    return normaliser * exp(-x * x / 2.0);
}

/**
 * @brief Loads the image and prepares the working buffers.
 *
 * src receives the image as read by imread, img_Luv its Luv conversion, filt a copy of src
 * (later overwritten by meanshiftFilter), Label a zero-filled CV_16UC1 label map of the same
 * size, and the region counter is reset to 0.
 *
 * @param path path of the image file to segment
 * @throws cv::Exception if the image cannot be read
 */
MeanShift::MeanShift(string &path) {
    this->src = imread(path);
    // imread reports failure with an empty Mat. Fail here with a clear message instead of the
    // unhelpful "!_src.empty()" assertion cvtColor would raise below.
    if (this->src.empty()) {
        CV_Error(cv::Error::StsObjectNotFound, "MeanShift: cannot read image '" + path + "'");
    }
    cvtColor(src, img_Luv, COLOR_BGR2Luv);
    this->filt = src.clone();
    this->Label = Mat::zeros(this->src.rows, this->src.cols, CV_16UC1);
    this->region = 0;
}

/**
 * @brief Mean-shift filtering: replaces every pixel by the colour its centroid converges to.
 *
 * For each pixel (i, j) of src, a centroid starting at (i, j) is repeatedly moved to the
 * weighted mean position of the src pixels inside a disc of radius sr around it. A pixel's
 * weight is Guass(colour distance / tr), where the colour distance is the Euclidean distance
 * to the centroid's current colour, so similarly coloured pixels pull harder. After every move
 * the centroid colour is re-sampled: by bilinear interpolation inside the image, or from the
 * clamped border pixel on the last row/column. Iteration stops when:
 *   - the move is below EQUAL_ERR on both axes;
 *   - the colour change is below color_threshold;
 *   - no pixel in the disc has a non-zero weight; or
 *   - max_itr iterations have run.
 * The final centroid colour is written to filt(i, j).
 *
 * @param sr spatial radius of the search disc, in pixels
 * @param tr colour bandwidth; must be > 0
 */
void MeanShift::meanshiftFilter(const double sr = 6, const double tr = 8) {
    assert(tr > 0);
    for (int i = 0; i < src.rows; ++i) {
        auto dst_rows = filt.ptr<Vec3b>(i);
        for (int j = 0; j < src.cols; ++j) {
            double cur_row = i; // centroid row during the iteration
            double cur_col = j; // centroid column during the iteration
            auto cur_color = src.ptr<Vec3b>(i)[j]; // centroid colour during the iteration
            // Keep moving the centroid for at most max_itr iterations.
            for (int k = 0; k < this->max_itr; ++k) {
                // Bounding box of the search disc, padded by one pixel and clamped to the image.
                double ltr = (cur_row - sr - 1); // top-left row
                double ltc = (cur_col - sr - 1); // top-left column
                double rbr = (cur_row + sr + 1); // bottom-right row (exclusive)
                double rbc = (cur_col + sr + 1); // bottom-right column (exclusive)
                ltr = ltr < 0 ? 0 : ltr;
                ltc = ltc < 0 ? 0 : ltc;
                rbr = rbr > src.rows ? src.rows : rbr;
                rbc = rbc > src.cols ? src.cols : rbc;

                double displacement_r = 0; // weighted sum of row offsets
                double displacement_c = 0; // weighted sum of column offsets
                double denominator = 0;    // sum of weights

                // l and m index pixels, so they have to be integers.
                for (auto l = static_cast<unsigned long>(ltr); l < rbr; l++) {
                    for (auto m = static_cast<unsigned long>(ltc); m < rbc; m++) {
                        // Spatial offset from the centroid.
                        double temp = cur_row - l;
                        double distant_sq = 0;
                        distant_sq += (temp * temp);
                        temp = cur_col - m;
                        distant_sq += (temp * temp);
                        // The box is square; only the pixels inside the disc are used.
                        if (distant_sq <= sr * sr) {
                            // Euclidean colour distance to the centroid colour.
                            double color_distance = 0;
                            for (int c = 0; c < 3; c++) {
                                temp = src.ptr<Vec3b>(l)[m][c] - cur_color[c];
                                color_distance += (temp * temp);
                            }
                            color_distance = sqrt(color_distance);

                            // Guass decreases with |x|: the closer the colour, the LARGER the weight.
                            double weight = Guass(color_distance / tr);
                            displacement_r += (l - cur_row) * weight;
                            displacement_c += (m - cur_col) * weight;
                            denominator += weight;
                        }
                    }
                }
                // With a small tr every weight can underflow to 0. The mean is then 0/0 = NaN, which
                // would later be converted to an int pixel index (out-of-bounds read). Treat this as
                // converged and keep the current colour.
                if (denominator <= 0) {
                    dst_rows[j] = cur_color;
                    break;
                }
                double dr = displacement_r / denominator; // row shift of this iteration
                double dc = displacement_c / denominator; // column shift of this iteration
                cur_row += dr;
                cur_col += dc;
                Vec<uchar, 3> pre_color = cur_color; // centroid colour before this update
                if (cur_row < 0 || cur_row >= src.rows - 1 || cur_col < 0 || cur_col >= src.cols - 1) {
                    // No full 2x2 neighbourhood for bilinear interpolation: clamp to the border and
                    // take that pixel's colour.
                    int cur_r = cur_row;
                    int cur_c = cur_col;
                    cur_r = cur_r < 0 ? 0 : cur_r;
                    cur_c = cur_c < 0 ? 0 : cur_c;
                    cur_r = cur_r >= src.rows - 1 ? src.rows - 1 : cur_r;
                    cur_c = cur_c >= src.cols - 1 ? src.cols - 1 : cur_c;
                    cur_color = src.ptr<Vec3b>(cur_r)[cur_c];
                } else {
                    // Bilinear interpolation of the centroid colour. (lt_r, lt_c) is the top-left pixel
                    // of the 2x2 neighbourhood; the int conversion truncates, so cur_* >= lt_*.
                    for (int c = 0; c < 3; c++) {
                        int lt_r = cur_row, lt_c = cur_col;
                        double temp1 = src.ptr<Vec3b>(lt_r)[lt_c][c] + (cur_col - lt_c) *
                                                                       (src.ptr<Vec3b>(lt_r)[lt_c + 1][c] -
                                                                        src.ptr<Vec3b>(lt_r)[lt_c][c]);
                        double temp2 = src.ptr<Vec3b>(lt_r + 1)[lt_c][c] + (cur_col - lt_c) *
                                                                           (src.ptr<Vec3b>(lt_r + 1)[lt_c + 1][c] -
                                                                            src.ptr<Vec3b>(lt_r + 1)[lt_c][c]);
                        cur_color[c] = temp1 + (cur_row - lt_r) * (temp2 - temp1);
                    }
                }
                // Squared colour change between the previous and the new centroid colour.
                double color_distance = 0;
                for (int c = 0; c < 3; c++) {
                    double temp = pre_color[c] - cur_color[c];
                    color_distance += (temp * temp);
                }
                // Stop when the position or the colour has settled, or on the last allowed iteration,
                // and store the current colour as the filtered value.
                if ((abs(dr) < EQUAL_ERR && abs(dc) < EQUAL_ERR) ||
                    color_distance < color_threshold * color_threshold || k == max_itr - 1) {
                    dst_rows[j] = cur_color;
                    // The convergence point (cur_row, cur_col) could be recorded here.
                    break;
                }
            }
        }
    }
}

/**
 * @brief 为不同区域打上标签
 * @param tr 阈值
 */
void MeanShift::buildLabel(const double tr = 8) {
    Vec3b curcolor;//当前位置像素值
    Vec3b average;//区域像素值平均值
    double tr_sq = tr * tr;//阈值的平方
    stack<Point2l> S;
    for (unsigned long i = 0; i < filt.rows; i++) {
        for (unsigned long j = 0; j < filt.cols; j++) {
            auto Label_p = Label.ptr<uint16_t>(i);
            if (Label_p[j] <= 0)//倘若该点未标注，即未属于某块区域
            {
                region++;
                Label_p[j] = region;
                S.push(Point2l(i, j));
                curcolor = filt.ptr<Vec3b>(i)[j];
                num = 1;
                average = curcolor;//平均像素值

                while (true) {
                    if (S.empty()) { break; }
                    Point2l curpos = S.top();//返回堆栈的当前顶部元素
                    S.pop();//剔除栈顶值
                    //遍历中心点周围8点
                    for (int k = 0; k < 8; k++) {
                        unsigned long r = curpos.x + displacement[k][0];
                        unsigned long c = curpos.y + displacement[k][1];
                        if (r >= 0 && r < filt.rows && c >= 0 && c < filt.cols && Label.ptr<uint16_t>(r)[c] <= 0) {
                            Vec3b cl = filt.ptr<Vec3b>(r)[c];
                            double color_distance_sq = 0;//像素值偏差的平方
                            //计算周围点与中间点像素值偏差
                            for (int channel = 0; channel < 3; channel++) {
                                double temp = curcolor[channel] - cl[channel];
                                color_distance_sq += (temp * temp);
                            }
                            //倘若小于规定,周边点像素值同中心点
                            if (color_distance_sq <= tr_sq) {
                                Label.ptr<uint16_t>(r)[c] = region;
                                S.push(Point2l(r, c));//延伸至周围
                                num++;//区域符合要求的点数自增

                                for (int channel = 0; channel < 3; channel++) {
                                    average[channel] += cl[channel];
                                }
                            }
                        }
                    }
                }
                //计算出该块近似区域的像素值
                Colors.emplace_back(average[0] / num, average[1] / num, average[2] / num);
                //记录该块区域符合要求的点数
                pointnum.push_back(num);
            }
        }
    }
}

/**
 * @brief 为相邻区域创建链表
 */
 // TODO Label存在为0的像素值
void MeanShift::buildChain() {
    adjoint_table.reserve(region);//为所有划分出的区域创建相邻表
    for (int i = 0; i < region; i++) { adjoint_table.emplace_back(); }//初始化，方便下标索引
    adjoint_table_head.reserve(region);
    for (int i = 0; i < region; i++) { adjoint_table_head.push_back(i + 1); }
    int L1 = 0, L2 = 0;
    for (unsigned long i = 0; i < Label.rows; i++) {
        for (unsigned long j = 0; j < Label.cols; j++) {//由于是由上至下，由左至右遍历，故只对前行列进行对比即可
            if (i > 0 && Label.ptr<uint16_t>(i)[j] != Label.ptr<uint16_t>(i - 1)[j])//撇开首行，上下行相邻标签不同视为边界
            {
                L1 = Label.ptr<uint16_t>(i)[j], L2 = Label.ptr<uint16_t>(i - 1)[j];
                vector<int>::iterator it = find(adjoint_table[L1 - 1].begin(),
                                                adjoint_table[L1 - 1].end(), L2);//查找相邻表中是否有L2
                if (adjoint_table[L1 - 1].end() == it) {
                    adjoint_table[L1 - 1].push_back(L2);//相邻表内存入相邻区域的序号
                    adjoint_table[L2 - 1].push_back(L1);
                }
            }
            if (j > 0 && Label.ptr<uint16_t>(i)[j] != Label.ptr<uint16_t>(i)[j - 1]) {
                L1 = Label.ptr<uint16_t>(i)[j], L2 = Label.ptr<uint16_t>(i)[j - 1];
                vector<int>::iterator it = find(adjoint_table[L1 - 1].begin(),
                                                adjoint_table[L1 - 1].end(), L2);
                if (adjoint_table[L1 - 1].end() == it) {
                    adjoint_table[L1 - 1].push_back(L2);
                    adjoint_table[L2 - 1].push_back(L1);
                }
            }
        }
    }
}


/**
 * @brief 合并相邻相似区域
 * @param max_iter 最大迭代次数
 * @param tr 阈值
 */
void MeanShift::mergeSimilarity(int max_iter = 5, const double tr = 8) {
    // 合并相似区域
    for (int i = 0, deltaRegion = 1; i < max_iter && deltaRegion > 0; i++) {
        //遍历合并表 r*c=region
        for (unsigned long r = 0; r < adjoint_table.size(); r++) {
            for (unsigned long c = 0; c < adjoint_table[r].size(); c++) {
                double color_distance_sq = 0;
                Vec3b color1 = Colors[r]/*c列r行像素值,定义时是以Row定义的*/, color2 = Colors[adjoint_table[r][c] - 1]/*c列r行相邻区域*/;
                for (int channel = 0; channel < 3; channel++) {
                    double temp = color1[channel] - color2[channel];
                    color_distance_sq += (temp * temp);
                }
                //相邻区域像素值偏差值小于阈值
                if (color_distance_sq <= tr * tr) {
                    unsigned long r_root = r + 1, c_root = adjoint_table[r][c];
                    while (adjoint_table_head[r_root - 1] != r_root)r_root = adjoint_table_head[r_root - 1];
                    while (adjoint_table_head[c_root - 1] != c_root)c_root = adjoint_table_head[c_root - 1];

                    if (r_root < c_root) { adjoint_table_head[c_root - 1] = r_root; }
                    else { adjoint_table_head[r_root - 1] = c_root; }
                }
            }
        }
        for (unsigned long ind = 0; ind < adjoint_table_head.size(); ind++) {
            unsigned long ind_root = ind + 1;
            while (adjoint_table_head[ind_root - 1] != ind_root)ind_root = adjoint_table_head[ind_root - 1];
            adjoint_table_head[ind] = ind_root;
        }

        vector<Vec3b> new_Colors;
        new_Colors.reserve(region);
        for (int ind = 0; ind < region; ind++) { new_Colors.emplace_back(0, 0, 0); }
        vector<int> new_pointnum;
        new_pointnum.reserve(region);
        for (int ind = 0; ind < region; ind++) { new_pointnum.push_back(0); }
        for (unsigned long ind = 0; ind < adjoint_table_head.size(); ind++) {
            unsigned long ind_ = adjoint_table_head[ind] - 1;
            new_pointnum[ind_] += pointnum[ind];
            for (int channel = 0; channel < 3; channel++) {
                new_Colors[ind_][channel] += Colors[ind][channel] * pointnum[ind];
            }
        }

        vector<int> new_adjoint_table_head;
        new_adjoint_table_head.reserve(region);
        for (int ind = 0; ind < region; ind++) { new_adjoint_table_head.push_back(0); }
        int label = 0;
        Colors.clear();
        pointnum.clear();
        for (int ind = 0; ind < region; ind++) {
            unsigned long ind_ = adjoint_table_head[ind] - 1;
            if (new_adjoint_table_head[ind_] <= 0) {
                label++;
                new_adjoint_table_head[ind_] = label;
                num = new_pointnum[ind_];
                Vec3b cl = new_Colors[ind_];
                pointnum.push_back(num);
                Colors.emplace_back(cl[0] / num, cl[1] / num, cl[2] / num);
            }
        }
        vector<vector<int> > new_adjoint_table;
        new_adjoint_table.reserve(label);
        for (int ind = 0; ind < label; ind++) { new_adjoint_table.emplace_back(); }
        for (int ind = 0; ind < region; ind++) {
            int ind_ = new_adjoint_table_head[adjoint_table_head[ind] - 1] - 1;
            for (int j = 0; j < adjoint_table[ind].size(); j++) {
                int new_label = new_adjoint_table_head[adjoint_table_head[adjoint_table[ind][j] - 1] - 1];
                if (new_label != ind_ + 1) {
                    vector<int>::iterator it = find(new_adjoint_table[ind_].begin(),
                                                    new_adjoint_table[ind_].end(), new_label);
                    if (it == new_adjoint_table[ind_].end()) {
                        new_adjoint_table[ind_].push_back(new_label);
                    }
                }
            }
        }

        deltaRegion = region - label;
        region = label;
        adjoint_table = new_adjoint_table;
        for (unsigned long r = 0; r < Label.rows; r++) {
            for (unsigned long c = 0; c < Label.cols; c++) {
                Label.ptr<uint16_t>(r)[c] = new_adjoint_table_head[adjoint_table_head[Label.ptr<uint16_t>(r)[c] - 1] -
                                                                   1];
            }
        }
        adjoint_table_head.clear();
        adjoint_table_head.reserve(region);
        for (int ind = 0; ind < region; ind++) { adjoint_table_head.push_back(ind + 1); }
    }
}

/**
 * @brief 移除较小面积区域
 */
void MeanShift::removeSmallArea() {
    double mindist_sq;
    for (int deltaRegion = 1; deltaRegion > 0;) {
        for (unsigned long r = 0; r < adjoint_table.size(); r++) {
            mindist_sq = 1e12;  // Bigger than any possible color distance square.
            if (pointnum[r] < minarea) {
                unsigned long c_root = r + 1;
                for (unsigned long c = 0; c < adjoint_table[r].size(); c++) {
                    double color_distance_sq = 0;
                    Vec3b color1 = Colors[r], color2 = Colors[adjoint_table[r][c] - 1];
                    for (int channel = 0; channel < 3; channel++) {
                        double temp = color1[channel] - color2[channel];
                        color_distance_sq += (temp * temp);
                    }
                    if (color_distance_sq < mindist_sq) {
                        mindist_sq = color_distance_sq;
                        c_root = adjoint_table[r][c];
                    }
                }

                unsigned long r_root = r + 1;
                while (adjoint_table_head[r_root - 1] != r_root)r_root = adjoint_table_head[r_root - 1];
                while (adjoint_table_head[c_root - 1] != c_root)c_root = adjoint_table_head[c_root - 1];

                if (r_root < c_root) { adjoint_table_head[c_root - 1] = r_root; }
                else { adjoint_table_head[r_root - 1] = c_root; }
            }
        }
        for (unsigned long ind = 0; ind < adjoint_table_head.size(); ind++) {
            unsigned long ind_root = ind + 1;
            while (adjoint_table_head[ind_root - 1] != ind_root)ind_root = adjoint_table_head[ind_root - 1];
            adjoint_table_head[ind] = ind_root;
        }

        vector<Vec3b> new_Colors;
        new_Colors.reserve(region);
        for (int ind = 0; ind < region; ind++) { new_Colors.emplace_back(0, 0, 0); }
        vector<int> new_pointnum;
        new_pointnum.reserve(region);
        for (int ind = 0; ind < region; ind++) { new_pointnum.push_back(0); }
        for (unsigned long ind = 0; ind < adjoint_table_head.size(); ind++) {
            unsigned long ind_ = adjoint_table_head[ind] - 1;
            new_pointnum[ind_] += pointnum[ind];
            for (int channel = 0; channel < 3; channel++) {
                new_Colors[ind_][channel] += Colors[ind][channel] * pointnum[ind];
            }
        }

        vector<int> new_adjoint_table_head;
        new_adjoint_table_head.reserve(region);
        for (int ind = 0; ind < region; ind++) { new_adjoint_table_head.push_back(0); }
        int label = 0;
        Colors.clear();
        pointnum.clear();
        for (int ind = 0; ind < region; ind++) {
            unsigned long ind_ = adjoint_table_head[ind] - 1;
            if (new_adjoint_table_head[ind_] <= 0) {
                label++;
                new_adjoint_table_head[ind_] = label;
                num = new_pointnum[ind_];
                Vec3b cl = new_Colors[ind_];
                pointnum.push_back(num);
                Colors.emplace_back(cl[0] / num, cl[1] / num, cl[2] / num);
            }
        }
        vector<vector<int> > new_adjoint_table;
        new_adjoint_table.reserve(label);
        for (int ind = 0; ind < label; ind++) { new_adjoint_table.emplace_back(); }
        for (int ind = 0; ind < region; ind++) {
            int ind_ = new_adjoint_table_head[adjoint_table_head[ind] - 1] - 1;
            for (int j = 0; j < adjoint_table[ind].size(); j++) {
                int new_label = new_adjoint_table_head[adjoint_table_head[adjoint_table[ind][j] - 1] - 1];
                if (new_label != ind_ + 1) {
                    vector<int>::iterator it = find(new_adjoint_table[ind_].begin(),
                                                    new_adjoint_table[ind_].end(), new_label);
                    if (it == new_adjoint_table[ind_].end()) {
                        new_adjoint_table[ind_].push_back(new_label);
                    }
                }
            }
        }

        deltaRegion = region - label;
        region = label;
        adjoint_table = new_adjoint_table;
        for (unsigned long r = 0; r < Label.rows; r++) {
            for (unsigned long c = 0; c < Label.cols; c++) {
                Label.ptr<uint16_t>(r)[c] = new_adjoint_table_head[adjoint_table_head[Label.ptr<uint16_t>(r)[c] - 1] -
                                                                   1];
            }
        }
        adjoint_table_head.clear();
        adjoint_table_head.reserve(region);
        for (int ind = 0; ind < region; ind++) { adjoint_table_head.push_back(ind + 1); }
    }
}

/**
 * @brief 结果绘制
 */
void MeanShift::buildResult() {
    for (int i = 0; i < Label.rows; ++i) {
        for (int j = 0; j < Label.cols; ++j) {
            dst.ptr<Vec3b>(i)[j] = Colors[Label.ptr<uint16_t>(i)[j] - 1];
        }
    }
}

/**
 * @brief 主函数调用接口
 */
void MeanShift::meanshiftSegmentation() {
    this->meanshiftFilter();
    this->buildLabel();
    this->buildChain();
    this->mergeSimilarity();
    this->removeSmallArea();
    this->buildResult();

}

